{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded\n",
      "Files already downloaded\n"
     ]
    }
   ],
   "source": [
    "# MNIST CNN classifier \n",
    "# Code by GunhoChoi\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.utils as utils\n",
    "from torch.autograd import Variable\n",
    "import torchvision.datasets as dset\n",
    "import torchvision.transforms as transforms\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "# Set Hyperparameters\n",
    "\n",
    "epoch = 100\n",
    "batch_size =16\n",
    "learning_rate = 0.001\n",
    "\n",
    "# Download Data\n",
    "\n",
    "mnist_train = dset.MNIST(\"./\", train=True, transform=transforms.ToTensor(), target_transform=None, download=True)\n",
    "mnist_test  = dset.MNIST(\"./\", train=False, transform=transforms.ToTensor(), target_transform=None, download=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "60000\n",
      "10000\n",
      "torch.Size([1, 28, 28]) 5\n",
      "torch.Size([1, 28, 28]) 7\n"
     ]
    }
   ],
   "source": [
    "# Check the datasets downloaded\n",
    "\n",
    "print(mnist_train.__len__())\n",
    "print(mnist_test.__len__())\n",
    "img1,label1 = mnist_train.__getitem__(0)\n",
    "img2,label2 = mnist_test.__getitem__(0)\n",
    "\n",
    "print(img1.size(), label1)\n",
    "print(img2.size(), label2)\n",
    "\n",
    "# Set Data Loader(input pipeline)\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(dataset=mnist_train,batch_size=batch_size,shuffle=True)\n",
    "test_loader = torch.utils.data.DataLoader(dataset=mnist_test,batch_size=batch_size,shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, \n",
    "#                 padding=0, dilation=1, groups=1, bias=True)\n",
    "# torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1,\n",
    "#                    return_indices=False, ceil_mode=False)\n",
    "# torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1,affine=True)\n",
    "# torch.nn.ReLU()\n",
    "# tensor.view(newshape)\n",
    "\n",
    "class CNN(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(CNN,self).__init__()\n",
    "        self.layer1 = nn.Sequential(\n",
    "                        nn.Conv2d(1,16,5),   # batch x 16 x 24 x 24\n",
    "                        nn.ReLU(),\n",
    "                        nn.BatchNorm2d(16),\n",
    "                        nn.Conv2d(16,32,5),  # batch x 32 x 20 x 20\n",
    "                        nn.ReLU(),\n",
    "                        nn.BatchNorm2d(32),\n",
    "                        nn.MaxPool2d(2,2)   # batch x 32 x 10 x 10\n",
    "        )\n",
    "        self.layer2 = nn.Sequential(\n",
    "                        nn.Conv2d(32,64,5),  # batch x 64 x 6 x 6\n",
    "                        nn.ReLU(),\n",
    "                        nn.BatchNorm2d(64),\n",
    "                        nn.Conv2d(64,128,5),  # batch x 128 x 2 x 2\n",
    "                        nn.ReLU()\n",
    "        )\n",
    "        self.fc = nn.Linear(2*2*128,10)\n",
    "        \n",
    "    def forward(self,x):\n",
    "        out = self.layer1(x)\n",
    "        out = self.layer2(out)\n",
    "        out = out.view(batch_size, -1)\n",
    "        out = self.fc(out)\n",
    "        return out\n",
    "        \n",
    "cnn = CNN()\n",
    "\n",
    "loss_func = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Variable containing:\n",
      " 2.3027\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      "1.00000e-02 *\n",
      "  8.0201\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      "1.00000e-02 *\n",
      "  3.7111\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      "1.00000e-03 *\n",
      "  1.8575\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      " 0.3383\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      "1.00000e-02 *\n",
      "  1.5311\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      " 0.7637\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      "1.00000e-03 *\n",
      "  9.3253\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      "1.00000e-02 *\n",
      "  9.1775\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      "1.00000e-02 *\n",
      "  3.4686\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      "1.00000e-02 *\n",
      "  1.0231\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      "1.00000e-03 *\n",
      "  2.4441\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      "1.00000e-03 *\n",
      "  3.0596\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      " 0.2678\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      "1.00000e-04 *\n",
      "  2.1912\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      " 0.7921\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      "1.00000e-02 *\n",
      "  1.0485\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      "1.00000e-02 *\n",
      "  1.2817\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      "1.00000e-03 *\n",
      "  7.0363\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      " 0.2775\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      " 0.5993\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      "1.00000e-02 *\n",
      "  2.0506\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      "1.00000e-02 *\n",
      "  6.7827\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      "1.00000e-03 *\n",
      "  5.5896\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      "1.00000e-03 *\n",
      "  3.1059\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      " 0.1066\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      " 0.1235\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      "1.00000e-03 *\n",
      "  3.6109\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      "1.00000e-02 *\n",
      "  1.8422\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      " 0.3572\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      "1.00000e-03 *\n",
      "  2.6084\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      " 0.3257\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      " 0.1525\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      "1.00000e-02 *\n",
      "  1.0409\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      "1.00000e-04 *\n",
      "  3.9107\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      "1.00000e-03 *\n",
      "  1.8482\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      "1.00000e-02 *\n",
      "  8.0254\n",
      "[torch.FloatTensor of size 1]\n",
      "\n",
      "Variable containing:\n",
      "1.00000e-02 *\n",
      "  1.0144\n",
      "[torch.FloatTensor of size 1]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Train Model with train data\n",
    "# In order to use GPU you need to move all Variables and model by Module.cuda()\n",
    "\n",
    "for i in range(epoch):\n",
    "    for j,[image,label] in enumerate(train_loader):\n",
    "        image = Variable(image)\n",
    "        label = Variable(label)\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "        result = cnn.forward(image)\n",
    "        loss = loss_func(result,label)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        if j % 100 == 0:\n",
    "            print(loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of Test Data: 0.9864\n"
     ]
    }
   ],
   "source": [
    "# Test with test data\n",
    "# In order test, we need to change model mode to .eval()\n",
    "# and get the highest score label for accuracy\n",
    "\n",
    "cnn.eval()\n",
    "correct = 0\n",
    "total = 0\n",
    "\n",
    "for image,label in test_loader:\n",
    "    image = Variable(image)\n",
    "    result = cnn(image)\n",
    "    \n",
    "    _,pred = torch.max(result.data,1)\n",
    "    \n",
    "    total += label.size(0)\n",
    "    correct += (pred == label).sum()\n",
    "    \n",
    "print(\"Accuracy of Test Data: {}\".format(correct/total))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
